# Custom Statistic

In this notebook, we implement a new statistic as a subclass of `LinearFractionalStatistic`. Every fairret works with any `LinearFractionalStatistic`, so all of them will also work with our new statistic.

We take inspiration from the paper "[Generalizing Group Fairness in Machine Learning via Utilities](https://jair.org/index.php/jair/article/view/14238/26985)" by Blandin and Kash. For the well-known German Credit Dataset, they propose the following cost for predictions $\hat{Y}$ and ground truth labels $Y$:
$$C = \begin{cases}
 0 & \text{ if } \hat{Y} = Y \\
 1 & \text{ if } \hat{Y} = 0 \wedge Y = 1 \\
 5 & \text{ if } \hat{Y} = 1 \wedge Y = 0
\end{cases}$$

The costs reflect that an applicant who receives a loan $(\hat{Y} = 1)$ but will not repay it $(Y = 0)$ defaults, which is considered far worse than rejecting $(\hat{Y} = 0)$ an applicant who would have repaid the loan $(Y = 1)$.
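For hard (0/1) predictions, the cost table can be written as a simple lookup. The sketch below is illustrative only: `COST` and `average_cost` are hypothetical names, not part of fairret, and they only cover deterministic predictions.

```python
# Cost C for each (prediction, label) pair, following the table above.
COST = {
    (0, 0): 0,  # correct rejection
    (1, 1): 0,  # correct approval
    (0, 1): 1,  # rejected an applicant who would have repaid
    (1, 0): 5,  # approved an applicant who will default
}


def average_cost(preds, labels):
    """Average cost C over a group, given hard 0/1 predictions and labels."""
    costs = [COST[(p, y)] for p, y in zip(preds, labels)]
    return sum(costs) / len(costs)


# Four applicants: one defaulter is approved (cost 5), one repayer is rejected (cost 1).
print(average_cost([1, 0, 1, 0], [0, 1, 1, 0]))  # 1.5
```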

The statistic in this case is the average cost $C$ incurred over all individuals in a sensitive group. Hence, the statistic is canonically formalized as
$$\gamma(k, f) = \frac{\mathbb{E}[SC]}{\mathbb{E}[S]} = \frac{\mathbb{E}[S(1 \cdot Y(1 - f(X)) + 5 (1 - Y)f(X))]}{\mathbb{E}[S]} = \frac{\mathbb{E}[S(Y + (5 - 6Y)f(X))]}{\mathbb{E}[S]}$$
where we replaced $\hat{Y}$ with the probabilistic prediction $f(X)$.
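The simplification in the last step can be checked numerically. This sketch compares the expected cost $1 \cdot Y(1 - f) + 5(1 - Y)f$ with the simplified form $Y + (5 - 6Y)f$ for both labels and a few probabilistic predictions; the function names are illustrative.

```python
def expected_cost(y, f):
    """Expected cost when predicting 1 with probability f (before simplifying)."""
    return 1 * y * (1 - f) + 5 * (1 - y) * f


def linear_form(y, f):
    """Simplified form: intercept Y plus slope (5 - 6Y) times f."""
    return y + (5 - 6 * y) * f


# For Y = 0 both reduce to 5f; for Y = 1 both reduce to 1 - f.
for y in (0, 1):
    for f in (0.0, 0.25, 0.5, 0.75, 1.0):
        assert abs(expected_cost(y, f) - linear_form(y, f)) < 1e-12
print("Both forms agree.")
```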

The canonical form shows that the statistic is linear-fractional in $f$. Ignoring $S$ for a moment, the numerator has intercept $Y$ and slope $(5 - 6Y)$. The denominator does not depend on $f$: its intercept is $1$ and its slope is $0$.
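Written against a generic linear-fractional template, each coefficient corresponds to one method of the class implemented below. The symbols $c_0, c_1, d_0, d_1$ are our own shorthand (assumed here, not fairret notation) for the values returned by `num_intercept`, `num_slope`, `denom_intercept` and `denom_slope`:

$$\gamma(k, f) = \frac{\mathbb{E}[S(c_0 + c_1 f(X))]}{\mathbb{E}[S(d_0 + d_1 f(X))]}, \qquad c_0 = Y, \quad c_1 = 5 - 6Y, \quad d_0 = 1, \quad d_1 = 0$$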

The statistic is then implemented as:

```python
import torch
from fairret.statistic import LinearFractionalStatistic

class CustomCost(LinearFractionalStatistic):
    # Numerator terms: intercept Y and slope (5 - 6Y) with respect to f(X).
    def num_intercept(self, label: torch.Tensor) -> torch.Tensor:
        return label

    def num_slope(self, label: torch.Tensor) -> torch.Tensor:
        return 5 - 6 * label

    # Denominator terms: intercept 1 and slope 0, so the denominator reduces to E[S].
    def denom_intercept(self, label: torch.Tensor) -> torch.Tensor:
        return 1

    def denom_slope(self, label: torch.Tensor) -> torch.Tensor:
        return 0.
```
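To see what the four terms produce, here is a framework-free sketch that mirrors them with plain Python floats. It assumes that the statistic for a sensitive group is the ratio of the means $\mathbb{E}[S(c_0 + c_1 f)]$ and $\mathbb{E}[S(d_0 + d_1 f)]$; `group_statistic` is a hypothetical helper, not fairret's implementation. The labels and one-hot groups match the toy data used in the example below, with an illustrative constant prediction of 0.5.

```python
# Plain-Python mirrors of CustomCost's four methods.
def num_intercept(y):
    return y


def num_slope(y):
    return 5 - 6 * y


def denom_intercept(y):
    return 1.0


def denom_slope(y):
    return 0.0


def group_statistic(preds, sens_col, labels):
    """Assumed form: E[S(c0 + c1 f)] / E[S(d0 + d1 f)] for one sensitive group."""
    n = len(preds)
    num = sum(s * (num_intercept(y) + num_slope(y) * f)
              for f, s, y in zip(preds, sens_col, labels)) / n
    denom = sum(s * (denom_intercept(y) + denom_slope(y) * f)
                for f, s, y in zip(preds, sens_col, labels)) / n
    return num / denom


labels = [0, 1, 0, 1]
sens = [[1, 0], [1, 0], [0, 1], [0, 1]]
preds = [0.5, 0.5, 0.5, 0.5]

for k in range(2):
    sens_col = [row[k] for row in sens]
    print(f"group {k}: {group_statistic(preds, sens_col, labels)}")
# group 0: 1.5
# group 1: 1.5
```

Because each group contains one applicant of each label, a constant predictor incurs the same average cost in both groups; the groups only diverge once the predictions depend on the features.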

Let's quickly try it out...

```python
import torch
torch.manual_seed(0)

# Toy data: 4 samples, 2 features, one-hot sensitive groups and binary labels.
feat = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
sens = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
label = torch.tensor([[0.], [1.], [0.], [1.]])

from fairret.loss import NormLoss

statistic = CustomCost()
norm_loss = NormLoss(statistic)

h_layer_dim = 16
lr = 1e-3
batch_size = 1024

def build_model():
    model = torch.nn.Sequential(
        torch.nn.Linear(feat.shape[1], h_layer_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_layer_dim, 1)
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer

from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(feat, sens, label)
# With only 4 samples and batch_size=1024, each epoch is a single full batch.
dataloader = DataLoader(dataset, batch_size=batch_size)
```

Without fairret...

```python
import numpy as np

nb_epochs = 100
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

# Evaluate the custom statistic per sensitive group on the trained model.
pred = torch.sigmoid(model(feat))
stat_per_group = statistic(pred, sens, label)
absolute_diff = torch.abs(stat_per_group[0] - stat_per_group[1])

print(f"The {statistic.__class__.__name__} for group 0 is {stat_per_group[0]}")
print(f"The {statistic.__class__.__name__} for group 1 is {stat_per_group[1]}")
print(f"The absolute difference is {absolute_diff}")
```

```
Epoch: 0, loss: 0.7091795206069946
Epoch: 1, loss: 0.7061765193939209
Epoch: 2, loss: 0.7033581733703613
Epoch: 3, loss: 0.7007156610488892
Epoch: 4, loss: 0.6982340812683105
Epoch: 5, loss: 0.6959078907966614
Epoch: 6, loss: 0.6937355995178223
Epoch: 7, loss: 0.6917158365249634
Epoch: 8, loss: 0.6898466944694519
Epoch: 9, loss: 0.6881252527236938
Epoch: 10, loss: 0.6865478754043579
Epoch: 11, loss: 0.6851094961166382
Epoch: 12, loss: 0.6838041543960571
Epoch: 13, loss: 0.6826250553131104
Epoch: 14, loss: 0.6815641522407532
Epoch: 15, loss: 0.6806124448776245
Epoch: 16, loss: 0.6797604560852051
Epoch: 17, loss: 0.6789975762367249
Epoch: 18, loss: 0.6783132553100586
Epoch: 19, loss: 0.6776963472366333
Epoch: 20, loss: 0.6771360039710999
Epoch: 21, loss: 0.6766215562820435
Epoch: 22, loss: 0.6761429309844971
Epoch: 23, loss: 0.6756909489631653
Epoch: 24, loss: 0.6752569675445557
Epoch: 25, loss: 0.6748337745666504
Epoch: 26, loss: 0.674415111541748
Epoch: 27, loss: 0.673996090888977
...
```

With fairret...

```python
import numpy as np

nb_epochs = 100
fairness_strength = 1
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        # Add the fairret penalty, weighted by fairness_strength.
        loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

# Evaluate the custom statistic per sensitive group on the trained model.
pred = torch.sigmoid(model(feat))
stat_per_group = statistic(pred, sens, label)
absolute_diff = torch.abs(stat_per_group[0] - stat_per_group[1])

print(f"The {statistic.__class__.__name__} for group 0 is {stat_per_group[0]}")
print(f"The {statistic.__class__.__name__} for group 1 is {stat_per_group[1]}")
print(f"The absolute difference is {absolute_diff}")
```

```
Epoch: 0, loss: 0.7234874963760376
Epoch: 1, loss: 0.7193881869316101
Epoch: 2, loss: 0.7153821587562561
Epoch: 3, loss: 0.7114719748497009
Epoch: 4, loss: 0.7076587677001953
Epoch: 5, loss: 0.703943133354187
Epoch: 6, loss: 0.7003254294395447
Epoch: 7, loss: 0.6968045234680176
Epoch: 8, loss: 0.6933779120445251
Epoch: 9, loss: 0.7009283900260925
Epoch: 10, loss: 0.70442134141922
Epoch: 11, loss: 0.7051329016685486
Epoch: 12, loss: 0.7039238810539246
Epoch: 13, loss: 0.7013260126113892
Epoch: 14, loss: 0.6976962685585022
Epoch: 15, loss: 0.693289577960968
Epoch: 16, loss: 0.6954131722450256
Epoch: 17, loss: 0.6971543431282043
Epoch: 18, loss: 0.6984840035438538
Epoch: 19, loss: 0.6994330883026123
Epoch: 20, loss: 0.7000323534011841
Epoch: 21, loss: 0.700312077999115
Epoch: 22, loss: 0.7003009915351868
Epoch: 23, loss: 0.7000272870063782
Epoch: 24, loss: 0.699517011642456
Epoch: 25, loss: 0.6987953782081604
Epoch: 26, loss: 0.6978861689567566
Epoch: 27, loss: 0.6968110203742981
...
```
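The exact penalty that `NormLoss` adds in the second training loop is defined in the fairret documentation. The plain-Python sketch below only illustrates the general idea of a norm-based fairness penalty, under the assumption that each group's violation is its relative deviation from the overall statistic and that the penalty is the p-norm of those violations. It works on probabilities and Python floats rather than logits and tensors, and `norm_penalty` and its `p` argument are hypothetical.

```python
def cost(y, f):
    """Expected cost for one individual: Y + (5 - 6Y) f."""
    return y + (5 - 6 * y) * f


def overall_statistic(preds, labels):
    """Average cost over all individuals (S = 1 for everyone)."""
    return sum(cost(y, f) for f, y in zip(preds, labels)) / len(preds)


def group_statistic(preds, sens_col, labels):
    """Average cost within one sensitive group (hard one-hot membership)."""
    costs = [cost(y, f) for f, s, y in zip(preds, sens_col, labels) if s == 1]
    return sum(costs) / len(costs)


def norm_penalty(preds, sens, labels, p=1):
    """Hypothetical p-norm of each group's relative deviation from the overall cost."""
    target = overall_statistic(preds, labels)
    violations = []
    for k in range(len(sens[0])):
        sens_col = [row[k] for row in sens]
        violations.append(group_statistic(preds, sens_col, labels) / target - 1)
    return sum(abs(v) ** p for v in violations) ** (1 / p)


labels = [0, 1, 0, 1]
sens = [[1, 0], [1, 0], [0, 1], [0, 1]]
print(norm_penalty([0.5, 0.5, 0.5, 0.5], sens, labels))  # 0.0
print(norm_penalty([0.9, 0.1, 0.1, 0.9], sens, labels))
```

For the constant prediction, both groups and the full population have an average cost of 1.5, so the penalty is zero. In the second call, group 0's defaulter is likely approved while group 1's defaulter is likely rejected, which pushes the group costs apart and raises the penalty.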